package main

import (
	"bytes"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"regexp"
	"strings"

	"github.com/projectdiscovery/nvd"
	"github.com/projectdiscovery/sliceutil"
	"github.com/projectdiscovery/stringsutil"
	"gopkg.in/yaml.v3"

	"github.com/projectdiscovery/nuclei/v2/pkg/catalog"
)

// yamlIndentSpaces is the indentation width handed to the YAML encoder when
// the rewritten info block is serialized back into a template.
const yamlIndentSpaces = 2

var (
	input       = flag.String("i", "", "Templates to annotate")
	templateDir = flag.String("d", "", "Custom template directory for update")
)

// main parses the command line, requires both -i and -d to be non-empty, and
// runs process. Any failure terminates the program through log.Fatalf.
func main() {
	flag.Parse()

	bothProvided := *input != "" && *templateDir != ""
	if !bothProvided {
		log.Fatalf("invalid input, see -h\n")
	}

	err := process()
	if err != nil {
		log.Fatalf("could not process: %s\n", err)
	}
}

// process resolves the templates selected by -i through a catalog rooted at
// the -d directory and hands each one to getCVEData for annotation.
//
// nvd.NewClient is given a fresh temporary directory (presumably its download
// cache — confirm against the nvd package) that is removed when process
// returns. Per-template annotation problems are logged by getCVEData; only
// setup failures and unreadable template files are returned as errors.
func process() error {
	// TempDir substitutes its random suffix for the last "*" in the pattern;
	// it does not interpret printf verbs such as "%s".
	tempDir, err := ioutil.TempDir("", "nuclei-nvd-*")
	if err != nil {
		return fmt.Errorf("creating temp dir: %w", err)
	}
	defer os.RemoveAll(tempDir)

	client, err := nvd.NewClient(tempDir)
	if err != nil {
		return fmt.Errorf("creating nvd client: %w", err)
	}
	// Named templateCatalog so the catalog package is not shadowed.
	templateCatalog := catalog.New(*templateDir)

	paths, err := templateCatalog.GetTemplatePath(*input)
	if err != nil {
		return fmt.Errorf("resolving templates %q: %w", *input, err)
	}
	for _, path := range paths {
		data, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("reading template %s: %w", path, err)
		}
		getCVEData(client, path, string(data))
	}
	return nil
}

// Patterns used by getCVEData to pull the template's CVE ID and severity out
// of the raw template text without a full YAML decode; only the first match
// of each is used there.
var (
	// idRegex captures a CVE identifier (e.g. "CVE-2021-44228") that follows
	// "id: ". The "CVE" prefix is matched case-insensitively letter by letter.
	// It is not anchored, so "cve-id: CVE-..." lines match as well.
	idRegex = regexp.MustCompile(`id: ([Cc][Vv][Ee]-[0-9]+-[0-9]+)`)

	// severityRegex captures the lowercase word that follows "severity: ".
	severityRegex = regexp.MustCompile(`severity: ([a-z]+)`)
)

// maxReferenceCount is the limit consulted when NVD reference URLs are
// appended to a template's info.reference list.
const (
	maxReferenceCount = 5
)

// getCVEData annotates one template with data fetched from NVD.
//
// It reads the CVE ID and the current severity from the raw template text in
// data, fetches the CVE through client, and updates the template's info block:
//
//   - severity is replaced when it disagrees with the CVSS v3 base score;
//   - classification is refreshed when NVD has a CVSS v3 score and vector;
//   - an empty description is filled from the first NVD description entry;
//   - NVD reference URLs are appended until the list holds maxReferenceCount
//     entries.
//
// The info block is located by plain string search and swapped in textually,
// so the rest of the template keeps its original formatting. The file at
// filePath is rewritten only when something changed. Problems are logged and
// the template is skipped; nothing is reported back to the caller.
func getCVEData(client *nvd.Client, filePath, data string) {
	matches := idRegex.FindAllStringSubmatch(data, 1)
	if len(matches) == 0 {
		return
	}
	cveName := matches[0][1]

	severityMatches := severityRegex.FindAllStringSubmatch(data, 1)
	if len(severityMatches) == 0 {
		return
	}
	severityValue := severityMatches[0][1]

	// Locate the info block before hitting the network: templates whose info
	// block cannot be delimited cannot be rewritten anyway.
	infoBlockClean, ok := extractInfoBlock(data)
	if !ok {
		return
	}

	cveItem, err := client.FetchCVE(cveName)
	if err != nil {
		log.Printf("Could not fetch cve %s: %s\n", cveName, err)
		return
	}
	var cweID []string
	for _, problemData := range cveItem.CVE.Problemtype.ProblemtypeData {
		for _, description := range problemData.Description {
			cweID = append(cweID, description.Value)
		}
	}
	cvssScore := cveItem.Impact.BaseMetricV3.CvssV3.BaseScore
	cvssMetrics := cveItem.Impact.BaseMetricV3.CvssV3.VectorString

	infoBlock := InfoBlock{}
	if err := yaml.Unmarshal([]byte(data), &infoBlock); err != nil {
		// Continuing with a zero-valued InfoBlock would overwrite the template's
		// real name, author, tags, etc. with empty values.
		log.Printf("Could not unmarshal info block: %s\n", err)
		return
	}

	var changed bool
	if newSeverity := isSeverityMatchingCvssScore(severityValue, cvssScore); newSeverity != "" {
		changed = true
		infoBlock.Info.Severity = newSeverity
		fmt.Printf("Adjusting severity for %s from %s=>%s (%.2f)\n", filePath, severityValue, newSeverity, cvssScore)
	}

	isCvssEmpty := cvssScore == 0 || cvssMetrics == ""
	hasCvssChanged := infoBlock.Info.Classification.CvssScore != cvssScore || cvssMetrics != infoBlock.Info.Classification.CvssMetrics
	if !isCvssEmpty && hasCvssChanged {
		changed = true
		infoBlock.Info.Classification.CvssMetrics = cvssMetrics
		infoBlock.Info.Classification.CvssScore = cvssScore
		infoBlock.Info.Classification.CveId = cveName
		// NVD uses these placeholder values when no specific CWE is assigned.
		if len(cweID) > 0 && cweID[0] != "NVD-CWE-Other" && cweID[0] != "NVD-CWE-noinfo" {
			infoBlock.Info.Classification.CweId = strings.Join(cweID, ",")
		}
	}

	// If there is no description field, fill it from the CVE information.
	if infoBlock.Info.Description == "" && len(cveItem.CVE.Description.DescriptionData) > 0 {
		changed = true
		// Removes newlines, backslashes, single quotes and tabs, then appends a
		// single trailing newline (NOTE(review): presumably so yaml emits a block
		// scalar — confirm).
		description := stringsutil.ReplaceAny(cveItem.CVE.Description.DescriptionData[0].Value, "", "\n", "\\", "'", "\t")
		infoBlock.Info.Description = description + "\n"
	}

	var referenceDataURLs []string
	for _, reference := range cveItem.CVE.References.ReferenceData {
		referenceDataURLs = append(referenceDataURLs, reference.URL)
	}
	if len(referenceDataURLs) > 0 && !sliceutil.ContainsItems(infoBlock.Info.Reference, referenceDataURLs) {
		existing := sliceutil.PruneEmptyStrings(sliceutil.Dedupe(infoBlock.Info.Reference))
		merged, added := mergeReferences(existing, referenceDataURLs, maxReferenceCount)
		infoBlock.Info.Reference = merged
		if added {
			changed = true
		}
	}

	if !changed {
		return
	}

	var newInfoBlock bytes.Buffer
	yamlEncoder := yaml.NewEncoder(&newInfoBlock)
	yamlEncoder.SetIndent(yamlIndentSpaces)
	if err := yamlEncoder.Encode(infoBlock); err != nil {
		log.Printf("Could not marshal info block: %s\n", err)
		return
	}
	newInfoBlockData := strings.TrimSuffix(newInfoBlock.String(), "\n")

	newTemplate := strings.ReplaceAll(data, infoBlockClean, newInfoBlockData)
	if err := ioutil.WriteFile(filePath, []byte(newTemplate), 0644); err != nil {
		log.Printf("Could not write updated template %s: %s\n", filePath, err)
		return
	}
	fmt.Printf("Wrote updated template to %s\n", filePath)
}

// extractInfoBlock returns the raw text of the template's info block: from the
// first "info:" up to, but not including, the nearest following top-level
// "requests:", "network:" or "variables:" key, with trailing newlines trimmed.
// Section keys are only recognized at the start of a line, so the same words
// inside a description do not cut the block short. ok is false when "info:"
// or every terminating key is missing.
func extractInfoBlock(data string) (block string, ok bool) {
	start := strings.Index(data, "info:")
	if start == -1 {
		return "", false
	}
	rest := data[start:]

	end := -1
	for _, key := range []string{"requests:", "network:", "variables:"} {
		idx := strings.Index(rest, "\n"+key)
		if idx != -1 && (end == -1 || idx < end) {
			end = idx
		}
	}
	if end == -1 {
		return "", false
	}
	return strings.TrimRight(rest[:end], "\n"), true
}

// mergeReferences appends each URL from candidates to existing, skipping empty
// strings and URLs already present, and stops once the list holds limit
// entries. It returns the resulting list and whether any URL was added.
func mergeReferences(existing, candidates []string, limit int) ([]string, bool) {
	seen := make(map[string]struct{}, len(existing)+len(candidates))
	for _, ref := range existing {
		seen[ref] = struct{}{}
	}
	added := false
	for _, ref := range candidates {
		if len(existing) >= limit {
			break
		}
		if ref == "" {
			continue
		}
		if _, dup := seen[ref]; dup {
			continue
		}
		seen[ref] = struct{}{}
		existing = append(existing, ref)
		added = true
	}
	return existing, added
}

// isSeverityMatchingCvssScore maps score onto the CVSS v3 qualitative bands
// (low 0.1-3.9, medium 4.0-6.9, high 7.0-8.9, critical 9.0-10.0). It returns
// the band name only when it differs from severity. It returns "" when score
// is zero, when score lies in no band, or when severity already agrees.
func isSeverityMatchingCvssScore(severity string, score float64) string {
	if score == 0.0 {
		return ""
	}

	var band string
	switch {
	case score >= 0.1 && score <= 3.9:
		band = "low"
	case score >= 4.0 && score <= 6.9:
		band = "medium"
	case score >= 7.0 && score <= 8.9:
		band = "high"
	case score >= 9.0 && score <= 10.0:
		band = "critical"
	}

	// Nothing to report when the score fell outside every band or the template
	// already carries the matching severity.
	if band == "" || band == severity {
		return ""
	}
	return band
}

// InfoBlock wraps the top-level "info" key of a nuclei template; it is the
// value yaml.Unmarshal decodes the whole template into and the value that is
// re-encoded to replace the template's info block. Cloning struct from nuclei
// as we don't want any validation.
type InfoBlock struct {
	Info TemplateInfo `yaml:"info"`
}

// TemplateClassification mirrors a template's info.classification mapping.
// getCVEData fills it from NVD data:
//
//   - cvss-metrics: the CVSS v3 vector string;
//   - cvss-score: the CVSS v3 base score;
//   - cve-id: the CVE identifier matched in the template text;
//   - cwe-id: comma-separated CWE IDs, not written when NVD's first entry is
//     the "NVD-CWE-Other" or "NVD-CWE-noinfo" placeholder.
//
// Every field is omitempty, so zero values are dropped when re-encoding.
//
// NOTE(review): CveId/CweId do not follow Go's initialism convention (CveID,
// CweID); renaming them requires updating every use in this file.
type TemplateClassification struct {
	CvssMetrics string  `yaml:"cvss-metrics,omitempty"`
	CvssScore   float64 `yaml:"cvss-score,omitempty"`
	CveId       string  `yaml:"cve-id,omitempty"`
	CweId       string  `yaml:"cwe-id,omitempty"`
}

// TemplateInfo mirrors the fields of a nuclei template's "info:" mapping that
// this tool decodes and writes back. When the info block is re-encoded, keys
// are emitted in field declaration order, so reordering these fields changes
// the layout of rewritten templates. getCVEData may replace Severity, fill an
// empty Description, extend Reference and update Classification. Other
// fields are only round-tripped.
//
// NOTE(review): nuclei may accept some of these keys (e.g. author, reference,
// tags) in more than one YAML shape, and metadata values are not always
// strings; a template whose shape does not fit these Go types will fail to
// decode — confirm against nuclei's template schema.
type TemplateInfo struct {
	Name           string                 `yaml:"name"`
	Author         string                 `yaml:"author"`
	Severity       string                 `yaml:"severity"`
	Description    string                 `yaml:"description,omitempty"`
	Reference      []string               `yaml:"reference,omitempty"`
	Remediation    string                 `yaml:"remediation,omitempty"`
	Classification TemplateClassification `yaml:"classification,omitempty"`
	Metadata       map[string]string      `yaml:"metadata,omitempty"`
	Tags           string                 `yaml:"tags,omitempty"`
}
